{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os, pdb, sys\n",
    "## Expose only GPU 0 to CUDA; this is set ahead of the `import torch` below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "## Add the current working directory to the module search path\n",
    "## (presumably so `diffuser` resolves when launched from the repo root -- confirm).\n",
    "sys.path.append('.')\n",
    "import torch\n",
    "## Allow TF32 in CUDA matmul and cuDNN, and request deterministic cuDNN algorithms.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "torch.backends.cudnn.deterministic = True\n",
    "import diffuser.utils as utils\n",
    "from diffuser.models import GaussianDiffusionPB\n",
    "from diffuser.guides.policies_compose import PolicyCompose\n",
    "from diffuser.utils.comp.comp_setup import ComposedParser\n",
    "from diffuser.utils.jupyter_utils import suppress_stdout\n",
    "from tap import Tap\n",
    "import argparse\n",
    "\n",
    "## You need to download 'rm2d-comp' from OneDrive Link in README.md to launch\n",
    "config = \"config/comp/Comp_rSmaze_nw6nw3_hExt0505gsd_engy_add.py\"\n",
    "## Typed argument schema: each annotated field is a command-line option with the given default.\n",
    "class Parser(Tap):\n",
    "    config: str = config\n",
    "    plan_n_maze: int = 1\n",
    "    \n",
    "#---------------------------------- setup ----------------------------------#\n",
    "from diffuser.guides.comp_plan_env import ComposedEnvPlanner\n",
    "from diffuser.guides.comp_plan_helper import ComposedDiffusionPlanner\n",
    "## Replace argv so that `parse_args()` below sees only `--config`,\n",
    "## not the arguments the notebook kernel itself was started with.\n",
    "sys.argv = ['PlaceHolder.py', '--config', config, ]\n",
    "\n",
    "## NOTE(review): suppress_stdout presumably silences prints emitted during setup -- confirm.\n",
    "with suppress_stdout():\n",
    "    args_comp = Parser().parse_args()\n",
    "    ## NOTE(review): cond_w looks like one conditioning weight per composed model -- confirm against ComposedParser.\n",
    "    args_comp.cond_w = [2.0, 2.0]\n",
    "    cp = ComposedParser(args_comp.config, args_cmd=args_comp)\n",
    "    args_train_list, args_list = cp.setup_args_list(args_comp, )\n",
    "    dplanner = ComposedDiffusionPlanner(args_train_list, args_list, cp.comp_dataset)\n",
    "\n",
    "## Collection of planning environments; later cells create and render envs through it.\n",
    "plan_env_list = dplanner.plan_env_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create an Env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffuser.utils.rm2d_render import RandStaticMazeRenderer\n",
    "import imageio\n",
    "from importlib import reload\n",
    "np.set_printoptions(precision=3)\n",
    "# import pb_diff_envs.environment.static.rand_maze2d_renderer as rmr; reload(rmr)\n",
    "from pb_diff_envs.environment.static.rand_maze2d_renderer import RandMaze2DRenderer\n",
    "rootdir = './visualization/'\n",
    "\n",
    "bs = 50 # batch size\n",
    "\n",
    "## Where the maze layout comes from:\n",
    "##   'dataset'    -> build the env stored at index `env_id`\n",
    "##   'user_input' -> build an env from the hand-written wall centers below\n",
    "env_source = 'dataset' # or 'user_input'\n",
    "if env_source == 'user_input':\n",
    "    env_id = 0 # placeholder\n",
    "    ## one array of wall-center coordinates per wall group; groups may differ in size\n",
    "    wall_locations_c = [\n",
    "        np.array([[0.805, 3.83 ],\n",
    "                [0.852, 1.6  ],\n",
    "                [2.474, 2.591],\n",
    "                [2.585, 1.248],\n",
    "                [4.116, 3.365],\n",
    "                [4.128, 2.096],]),\n",
    "        np.array([\n",
    "                [0.813, 3.648],\n",
    "                [1.617, 2.137],\n",
    "                [2.947, 2.581],\n",
    "                [2.956, 0.718],\n",
    "                [4.347, 3.962]]),\n",
    "    ]\n",
    "    ## size of the obstacles: 1x1 square block, given as [0.5, 0.5] per wall.\n",
    "    ## One row per wall in every group, so the extents always match the number\n",
    "    ## of walls actually listed above (no hard-coded wall count).\n",
    "    hExt_list_c = [np.array([[0.5, 0.5],] * len(wlocs)) for wlocs in wall_locations_c]\n",
    "    env = plan_env_list.create_env_by_pos(env_id, wall_locations_c, hExt_list_c)\n",
    "\n",
    "elif env_source == 'dataset':\n",
    "    env_id = 1 # input an index\n",
    "    env = plan_env_list.create_single_env(env_id)\n",
    "\n",
    "else:\n",
    "    ## fail here instead of with a NameError on `env` further down\n",
    "    raise ValueError(f'unknown env_source: {env_source!r}')\n",
    "\n",
    "\n",
    "## start / goal positions, repeated along axis 0 so there is one row per sample in the batch\n",
    "obs_start_np = np.array( [ [ 0.22, 4.53 ] ], dtype=np.float32 ).repeat( bs, 0 )\n",
    "target_np = np.array( [ [3.93, 0.60] ], dtype=np.float32  ).repeat( bs, 0 )\n",
    "\n",
    "e_wlocs = env.wall_locations\n",
    "hExt = plan_env_list.hExt_list[0, 0].tolist() ## [0.5, 0.5]\n",
    "env_hExts = np.array([hExt,] * len(e_wlocs)) ## duplicate the extent once per wall\n",
    "print('[Training Time] # of walls: 6')\n",
    "print(f'[Now] # of walls: {len(e_wlocs)}')\n",
    "\n",
    "## number of walls in each wall group of this env\n",
    "mz_renderer = RandMaze2DRenderer(num_walls_c=[len(wlocs) for wlocs in env.wloc_list_c], fig_dpi=200)\n",
    "mz_renderer.up_right = 'empty'\n",
    "mz_renderer.config['no_draw_traj_line'] = True\n",
    "mz_renderer.config['no_draw_col_info'] = True\n",
    "\n",
    "## preview: only the first start and the first goal are passed to the renderer\n",
    "img = mz_renderer.renders( np.concatenate([obs_start_np[0:1,], target_np[0:1,],]), plan_env_list.maze_size, e_wlocs, env_hExts )\n",
    "utils.plt_img(img, dpi=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Sanity check: display the shape of the wall-location array of the env created above.\n",
    "env.wall_locations.shape ## (n_walls, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the env-creation cell above first, then plan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import diffuser.guides.comp_plan_env as dcpe\n",
    "reload(dcpe)  ## re-execute the module so recent edits are picked up\n",
    "from diffuser.guides.comp_plan_env import ComposedEnvPlanner, ComposedRM2DReplanner\n",
    "\n",
    "## ---- sampling options ----\n",
    "use_ddim = True\n",
    "repl_dn_steps = 5\n",
    "use_normed_wallLoc = False\n",
    "\n",
    "## ---- planner and policy ----\n",
    "rm2d_planner = ComposedEnvPlanner()\n",
    "policy: PolicyCompose = dplanner.policy\n",
    "policy.ddim_eta = 1.0\n",
    "## return_diffusion is switched on here; a later cell reads policy.multi_hor_x0_perstep\n",
    "policy.return_diffusion = True\n",
    "print(f'policy guidance weight: {policy.cg_w}')\n",
    "print(f'policy ddim_steps: {policy.ddim_steps}')\n",
    "\n",
    "## ---- replanner (handed to the config below; do_replan is False there) ----\n",
    "wloc_np = env.wall_locations ## (num_walls, 2)\n",
    "repl_config = {'use_ddim': use_ddim, 'dn_steps': repl_dn_steps}\n",
    "replanner = ComposedRM2DReplanner(\n",
    "    policy,\n",
    "    dplanner.train_normalizer_list[0],\n",
    "    repl_config,\n",
    "    'cuda',\n",
    ")\n",
    "\n",
    "## ---- model inputs, created as cpu tensors ----\n",
    "obs_start_tensor = utils.to_torch(obs_start_np, device='cpu')\n",
    "target_tensor = utils.to_torch(target_np, device='cpu')\n",
    "## (num_walls, 2) -> (bs, num_walls, 2) -> (bs, num_walls * 2)\n",
    "wloc_input = wloc_np[None, ].repeat(bs, axis=0).reshape(bs, -1)\n",
    "wall_locations = utils.to_torch(wloc_input, device='cpu')\n",
    "\n",
    "## ---- planning configuration: one problem, bs samples for it ----\n",
    "plan_env_config = utils.Config(\n",
    "    None,\n",
    "    env_mazelist=dplanner.plan_env_list,\n",
    "    prob_permaze=1,\n",
    "    samples_perprob=bs,\n",
    "    obs_selected_dim=(0,1),\n",
    "    horizon_pl_list=[40, 48, 52, 56, 60, 64],\n",
    "    horizon_wtrajs=48,\n",
    "    do_replan=False,\n",
    "    replanner=replanner,\n",
    ")\n",
    "\n",
    "with suppress_stdout():\n",
    "    ## samples_mulho holds the motion candidates of every planning horizon;\n",
    "    ## the plans themselves are samples_mulho[i].observations (trajectories of states)\n",
    "    _, samples_mulho, _ = rm2d_planner.plan_env_interact(\n",
    "        env,\n",
    "        policy,\n",
    "        obs_start_tensor,\n",
    "        target_tensor,\n",
    "        wall_locations,\n",
    "        use_normed_wallLoc,\n",
    "        plan_env_config,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffuser.utils.jupyter_utils import get_all_suc_trajs\n",
    "\n",
    "## flatten the candidates of all horizons into one list, horizon after horizon\n",
    "all_candidates = [cand for per_horizon in samples_mulho for cand in per_horizon.observations]\n",
    "print('len all_candidates:', len(all_candidates))\n",
    "\n",
    "## keep the successful trajectories together with their indices into all_candidates\n",
    "env.robot.collision_eps = 1e-2\n",
    "suc_trajs, suc_idxs = get_all_suc_trajs(env, trajs=all_candidates, goal=target_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Render Motion Plans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Rendering\n",
    "from diffuser.utils.rm2d_render import RandStaticMazeRenderer\n",
    "\n",
    "rd = RandStaticMazeRenderer(plan_env_list)\n",
    "\n",
    "## make sure the output directory exists before an image is written into it,\n",
    "## and join the path so a trailing slash in rootdir does not produce a double slash\n",
    "os.makedirs(rootdir, exist_ok=True)\n",
    "tmp_path = os.path.join(rootdir, 'luotest.png')\n",
    "\n",
    "## renderer settings of the selected env: trajectory line and collision info are drawn\n",
    "renderer = plan_env_list.model_list[env_id].renderer\n",
    "renderer.fig_dpi = 40 # lower resolution for fast render\n",
    "renderer.up_right = 'default'\n",
    "renderer.config['no_draw_traj_line'] = False\n",
    "renderer.config['no_draw_col_info'] = False\n",
    "\n",
    "## at most the first 10 successful trajectories; the image is read back from tmp_path\n",
    "rd.composite( tmp_path, suc_trajs[:10],  np.array([env_id])) # visualize suc trajs\n",
    "img = imageio.imread(tmp_path)\n",
    "utils.plt_img(img, dpi=200) # 600"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### create a denoising video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Render one frame per denoising step for the first successful trajectory.\n",
    "if len(suc_idxs) == 0:\n",
    "    raise RuntimeError('no successful trajectory found; re-run the planning cell before making the video')\n",
    "\n",
    "## all_candidates was filled horizon by horizon with bs samples each,\n",
    "## so a flat index splits into (horizon index, index within the batch)\n",
    "traj_idx = suc_idxs[0] # which path\n",
    "hor_idx = traj_idx // bs\n",
    "path_idx = traj_idx % bs # which path\n",
    "\n",
    "## multi_hor_x0_perstep: stores all the x0 in the denoising process\n",
    "n_steps = len(policy.multi_hor_x0_perstep[0]) ## number of denoising steps\n",
    "trajs_ps = []\n",
    "imgs_ps = []\n",
    "\n",
    "mz_renderer = RandMaze2DRenderer(num_walls_c=[len(wlocs) for wlocs in env.wloc_list_c], fig_dpi=200)\n",
    "mz_renderer.up_right = 'default'\n",
    "mz_renderer.config['no_draw_traj_line'] = False\n",
    "mz_renderer.config['no_draw_col_info'] = False\n",
    "\n",
    "## the maze layout does not change between denoising steps: compute it once\n",
    "e_wlocs = env.wall_locations\n",
    "env_hExts = np.concatenate(env.hExt_list_c, axis=0)\n",
    "\n",
    "for i_st in range(n_steps):\n",
    "    ## x0 of the chosen path at denoising step i_st\n",
    "    traj_ps = policy.multi_hor_x0_perstep[hor_idx][i_st].observations[path_idx]\n",
    "    trajs_ps.append(traj_ps)\n",
    "    img = mz_renderer.renders( traj_ps , plan_env_list.maze_size, e_wlocs, env_hExts )\n",
    "    imgs_ps.append(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('rootdir:', rootdir)\n",
    "import diffuser.utils.video as duv; reload(duv)\n",
    "from diffuser.utils.video import save_images_to_mp4, read_mp4_to_numpy\n",
    "\n",
    "## make sure the output directory exists, and join the path so a trailing\n",
    "## slash in rootdir does not produce a double slash\n",
    "os.makedirs(rootdir, exist_ok=True)\n",
    "mp4_name = os.path.join(rootdir, f'comp_maze2d_env{env_id}.mp4') # {utils.get_time()}\n",
    "## Save the per-step frames to a video\n",
    "save_images_to_mp4(imgs_ps, mp4_name, fps=15, st_sec=1, end_sec=1)\n",
    "## Load the saved video back and display it inline\n",
    "import mediapy as media\n",
    "from mediapy import show_video, read_video\n",
    "show_video(read_mp4_to_numpy(mp4_name), fps=15, width=400)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Display the number of denoising steps, i.e. the number of frames rendered for the video.\n",
    "n_steps"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
